#include "gmm_image_division.h"

/**
 * @brief Constructs the GMM image-segmentation view.
 * @param parent Parent widget, forwarded unchanged to CommonGraphicsView.
 *
 * The only extra setup is giving the window its title.
 */
GMM_Image_Division::GMM_Image_Division(QWidget *parent)
    : CommonGraphicsView{parent}
{
    // Window title shown to the user ("GMM image segmentation").
    setWindowTitle("GMM图像分割");
}


/**
 * @brief Runs GMM segmentation on a file dropped onto the view.
 * @param event The drop event; only the first dropped URL is used.
 *
 * Drops that carry no URLs are ignored instead of indexing an empty list.
 * The local path is handed to OpenCV in the platform's local 8-bit encoding.
 * On Windows, cv::imread expects a narrow path in the active code page, so a
 * UTF-8 path (from QString::toStdString) fails for non-ASCII file names.
 */
void GMM_Image_Division::dropEvent(QDropEvent *event){
    const QMimeData *mime = event->mimeData();
    if(mime == nullptr || !mime->hasUrls()){
        return;
    }

    const QList<QUrl> urls = mime->urls();
    if(urls.isEmpty()){
        return;
    }

    const QString filePath = urls.first().toLocalFile();
    // Keep the encoded bytes alive in a named object while the callee uses them.
    const QByteArray nativePath = filePath.toLocal8Bit();
    showGMMImageDivision(nativePath.constData());
}


void GMM_Image_Division::showGMMImageDivision(const char *filePath){
    Mat src = imread(filePath);
    if(src.empty()){
        qDebug()<<"载入图像为空";
        return;
    }

    imshow("src",src);

    //获取原始图像的宽、高、通道数
    int width = src.cols;
    int height = src.rows;
    int dims = src.channels();

    //将原始图像数据转为double类型
    int numSamples = width*height;//总共的像素点个数
    Mat points(numSamples,dims,CV_64FC1);//存储转换后浮点数据的容器
    Mat labels;//分类标签
    //填充样本像素数据
    int index = 0;
    for(int row = 0;row<height;row++){
        for(int col = 0;col<width;col++){
            index = width*row+col;
            Vec3b rgb = src.at<Vec3b>(row,col);
            points.at<double>(index,0)= static_cast<int>(rgb[0]);
            points.at<double>(index,1)= static_cast<int>(rgb[1]);
            points.at<double>(index,2)= static_cast<int>(rgb[2]);
        }
    }

    //
    int numCluster = 3;//3分类
    Scalar clolors[] = {
        Scalar(255, 0, 0),
        Scalar(0, 255, 0),
        Scalar(0, 0, 255),
        Scalar(255, 255, 0)
    };
    //使用emm分类模型进行分类
    Ptr<EM> emModel = EM::create();
    emModel->setClustersNumber(numCluster);//设置分类个数
    emModel->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);//设置协方差矩阵模型
    emModel->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));//设置停止条件
    emModel->trainEM(points,noArray(),labels,noArray());//训练

    //根据标签分类
    Mat result = Mat::zeros(src.size(),CV_8UC3);

    for(int row=0;row<height;row++){
        for(int col=0;col<width;col++){
            index = row*width+col;
            int label = labels.at<int>(index,0);
            Scalar color = clolors[label];
            result.at<Vec3b>(row,col)[0] = color[0];
            result.at<Vec3b>(row,col)[1] = color[1];
            result.at<Vec3b>(row,col)[2] = color[2];
        }
    }

    imshow("result",result);

}

